import os
import torch
import numpy as np
from torch import nn
from torchvision import transforms
from cvx2 import WidthBlock, DOWidthBlock
from cvx2.wrapper import ImageClassifyModelWrapper

# Dataset root. Set the DATA_DIR environment variable to use a different
# location without editing the script; the original path is the default.
data_dir = os.environ.get('DATA_DIR', '/Users/summy/data/AI-face2')
# Alternative dataset (directory name "分子分型" means "molecular subtyping"):
# data_dir = '/Users/summy/Downloads/分子分型'

def build_model(block_cls, num_classes=2, imgsz=28, width=32, hidden=1024, dropout=0.2):
	"""Build the two-stage convolutional classifier used by this script.

	Args:
		block_cls: Feature block class, constructed as ``block_cls(c1=..., c2=...)``
			(e.g. ``WidthBlock`` or ``DOWidthBlock``). Assumed to preserve spatial
			size, as the original hard-coded ``32 * 49`` head for 28px inputs implies.
		num_classes: Number of output logits.
		imgsz: Side length of the square input images.
		width: Channel count produced by both feature blocks.
		hidden: Width of the hidden fully connected layer.
		dropout: Dropout probability applied after the hidden layer.

	Returns:
		An ``nn.Sequential``. Under the spatial-size assumption above, it maps
		``(N, 3, imgsz, imgsz)`` inputs to ``(N, num_classes)`` logits.

	Raises:
		ValueError: If ``imgsz`` is too small to survive the two 2x2 max-pools.
	"""
	# Two stride-2 max-pools: floor(floor(imgsz / 2) / 2) == imgsz // 4, which
	# matches MaxPool2d's default floor rounding for odd sizes.
	side = imgsz // 4
	if side < 1:
		raise ValueError(f'imgsz must be >= 4, got {imgsz}')
	return nn.Sequential(
		block_cls(c1=3, c2=width),
		nn.MaxPool2d(kernel_size=2, stride=2),
		block_cls(c1=width, c2=width),
		nn.MaxPool2d(kernel_size=2, stride=2),
		nn.Flatten(),
		nn.Linear(in_features=width * side * side, out_features=hidden),
		nn.Dropout(dropout),
		nn.SiLU(inplace=True),
		nn.Linear(in_features=hidden, out_features=num_classes),
	)


if __name__ == '__main__':
	# Single source of truth for the input size: the classifier head is derived
	# from it, so training, evaluation and the model can no longer disagree.
	imgsz = 28
	model = build_model(WidthBlock, num_classes=2, imgsz=imgsz)
	# Alternative backbone. An earlier run of this variant was noted at 98.24%;
	# it used 10 output classes, so it was presumably a different dataset.
	# model = build_model(DOWidthBlock, num_classes=10, imgsz=imgsz)

	wrapper = ImageClassifyModelWrapper(model)
	wrapper.train_evaluate(data=data_dir, imgsz=imgsz, epochs=1, monitor='val_loss')
	# wrapper.save(save_mode='best')
	wrapper.evaluate(data=os.path.join(data_dir, 'test'), imgsz=imgsz)
	# print(wrapper.predict(os.path.join(data_dir, 'val', 'zebra crossing'), imgsz=imgsz))
	# print(wrapper.predict_classes(os.path.join(data_dir, 'val', 'zebra crossing'), imgsz=imgsz))
	# print(wrapper.predict_classes_proba(os.path.join(data_dir, 'val', 'zebra crossing'), imgsz=imgsz))
	